import pickle
from PIL import Image
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from dataloader.split_train_test_video import *


class JointDataset(Dataset):
    """Dataset yielding an optical-flow stack together with RGB frame(s).

    Keys of ``flow_dict`` and ``rgb_dict`` have the form ``'<video>-<n>'``.
    In ``'train'`` mode ``<n>`` is the number of clips available for random
    sampling; in ``'val'`` mode it is the clip / frame index to load. The two
    dicts are indexed in parallel and must carry the same label per position.
    Labels are converted to zero-based integers by subtracting 1.

    Args:
        flow_dict: mapping ``'<video>-<n>'`` -> label for the flow stream.
        rgb_dict: mapping ``'<video>-<n>'`` -> label for the RGB stream.
        flow_dir: directory prefix containing ``u/`` and ``v/`` flow folders
            (string-concatenated, so it must end with a path separator).
        rgb_dir: directory prefix containing per-class RGB frame folders
            (string-concatenated, so it must end with a path separator).
        num_frames: number of consecutive flow frames stacked per sample.
        mode: ``'train'`` or ``'val'``.
    """

    def __init__(self, flow_dict, rgb_dict, flow_dir, rgb_dir, num_frames, mode):
        self.flow_keys = list(flow_dict.keys())
        self.flow_values = list(flow_dict.values())
        self.rgb_keys = list(rgb_dict.keys())
        self.rgb_values = list(rgb_dict.values())
        self.flow_dir = flow_dir
        self.rgb_dir = rgb_dir
        self.mode = mode
        self.num_frames = num_frames
        self.img_rows = 224
        self.img_cols = 224
        self.rgb_transform = transforms.Compose([transforms.RandomCrop(224),
                                                 transforms.RandomHorizontalFlip(),
                                                 transforms.ToTensor(),
                                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                      std=[0.229, 0.224, 0.225])])
        self.rgb_val_transform = transforms.Compose([transforms.Resize([224, 224]),
                                                     transforms.ToTensor(),
                                                     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                          std=[0.229, 0.224, 0.225])])
        self.flow_transform = transforms.Compose([transforms.Resize([224, 224]),
                                                  transforms.ToTensor()])

    def stack_optical_flow(self):
        """Load the flow clip selected by the most recent ``__getitem__`` call.

        Reads ``self.flow_video`` and ``self.flow_clips_idx`` (both set by
        ``__getitem__``) and loads ``num_frames`` consecutive frames starting
        at that index from the ``u/`` and ``v/`` folders.

        Returns:
            A float tensor of shape ``(2 * num_frames, img_rows, img_cols)``.
            Channel ``2 * j`` holds the ``u/`` image of the j-th frame and
            channel ``2 * j + 1`` the matching ``v/`` image.
        """
        clip_name = 'v_' + self.flow_video
        u_dir = self.flow_dir + 'u/' + clip_name
        v_dir = self.flow_dir + 'v/' + clip_name

        flow = torch.empty(2 * self.num_frames, self.img_rows, self.img_cols,
                           dtype=torch.float32)
        first_frame = int(self.flow_clips_idx)

        for j in range(self.num_frames):
            frame_file = '/frame' + str(first_frame + j).zfill(6) + '.jpg'
            # Context managers release the file handles even if a transform fails.
            with Image.open(u_dir + frame_file) as img_h, \
                    Image.open(v_dir + frame_file) as img_v:
                # j is zero-based, so frame j occupies channels 2j and 2j+1.
                # (Indexing with 2*(j-1) would wrap frame 0 to the last channels.)
                flow[2 * j, :, :] = self.flow_transform(img_h)
                flow[2 * j + 1, :, :] = self.flow_transform(img_v)

        return flow

    def load_rgb_image(self, video_name, index):
        """Load and transform one RGB frame of ``video_name``.

        Args:
            video_name: video identifier of the form ``'<class>_<rest>'``.
            index: frame number; zero-padded to five digits in the file name.

        Returns:
            The frame after ``rgb_transform`` (train) or ``rgb_val_transform``
            (val).

        Raises:
            ValueError: if ``self.mode`` is neither ``'train'`` nor ``'val'``.
        """
        class_name = video_name.split('_')[0]
        if class_name == 'HandstandPushups':
            # The RGB folders spell this class 'HandStandPushups'.
            _, rest = video_name.split('_', 1)
            path = self.rgb_dir + 'HandStandPushups' + '/v_' + 'HandStandPushups_' + rest + '/image_'
        else:
            path = self.rgb_dir + class_name + '/v_' + video_name + '/image_'

        with Image.open(path + str(index).zfill(5) + '.jpg') as img:
            if self.mode == 'train':
                return self.rgb_transform(img)
            if self.mode == 'val':
                return self.rgb_val_transform(img)
            raise ValueError('There are only train and val mode')

    @staticmethod
    def _sample_rgb_clips(nb_clips):
        """Draw one random frame index from each third of ``1..nb_clips + 1``.

        The segment bounds are clamped so that very short videos
        (``nb_clips < 3``) yield valid indices >= 1 instead of raising from
        an empty ``randint`` range. For ``nb_clips >= 3`` the bounds are
        unchanged.
        """
        first_cut = max(1, int(nb_clips / 3))
        second_cut = max(first_cut, int(nb_clips * 2 / 3))
        return [random.randint(1, first_cut),
                random.randint(first_cut, second_cut),
                random.randint(second_cut, nb_clips + 1)]

    def __len__(self):
        """Return the number of entries in the flow dictionary."""
        return len(self.flow_keys)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            In ``'train'`` mode ``(flow, rgb, label)`` where ``rgb`` is a dict
            with keys ``'img0'``, ``'img1'``, ``'img2'``. In ``'val'`` mode
            ``(video_name, flow, rgb, label)`` with a single RGB tensor.

        Raises:
            ValueError: if the mode is unknown, or the flow and RGB labels at
                ``idx`` differ.
        """
        if self.mode == 'train':
            self.flow_video, flow_nb_clips = self.flow_keys[idx].split('-')
            self.flow_clips_idx = random.randint(1, int(flow_nb_clips))

            rgb_video_name, rgb_nb_clips = self.rgb_keys[idx].split('-')
            rgb_clips = self._sample_rgb_clips(int(rgb_nb_clips))

        elif self.mode == 'val':
            self.flow_video, self.flow_clips_idx = self.flow_keys[idx].split('-')

            rgb_video_name, rgb_index = self.rgb_keys[idx].split('-')
            rgb_val_index = abs(int(rgb_index))
        else:
            raise ValueError('There are only train and val mode')

        if self.flow_values[idx] != self.rgb_values[idx]:
            raise ValueError('rgb label ({}) must equal to flow label ({})'.format(self.rgb_values[idx], self.flow_values[idx]))
        label = int(self.flow_values[idx]) - 1

        flow = self.stack_optical_flow()

        # The mode was validated above, so only 'train' and 'val' reach here.
        if self.mode == 'train':
            rgb = {'img' + str(i): self.load_rgb_image(rgb_video_name, frame_index)
                   for i, frame_index in enumerate(rgb_clips)}
            return flow, rgb, label

        rgb = self.load_rgb_image(rgb_video_name, rgb_val_index)
        return self.flow_video, flow, rgb, label


class JointDataloader:
    """Build train / validation ``DataLoader`` objects over ``JointDataset``.

    Args:
        batch_size: batch size used for both loaders.
        num_workers: number of worker processes used for both loaders.
        num_frames: number of optical-flow frames stacked per sample
            (e.g. 10 frames = 20 flow channels).
        flow_path: directory prefix handed to ``JointDataset`` as ``flow_dir``.
        rgb_path: directory prefix handed to ``JointDataset`` as ``rgb_dir``.
        ucf_list: passed to ``UCF101_splitter`` as ``path``.
        ucf_split: passed to ``UCF101_splitter`` as ``split``.
    """

    # Number of evenly spaced clips / frames sampled from every test video.
    VAL_SAMPLES_PER_VIDEO = 19

    def __init__(self, batch_size, num_workers, num_frames, flow_path, rgb_path, ucf_list, ucf_split):

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.frame_count = {}
        self.num_frames = num_frames  # number of frames for OF, e.g. 10 frames = 20 OF channels
        self.flow_path = flow_path
        self.rgb_path = rgb_path
        self.flow_dict_train = {}
        self.rgb_dict_train = {}
        self.flow_dict_test_idx = {}
        self.rgb_dict_test_idx = {}
        # split the training and testing videos
        splitter = UCF101_splitter(path=ucf_list, split=ucf_split)
        self.train_video, self.test_video = splitter.split_video()

    def load_frame_count(self):
        """Populate ``self.frame_count`` from ``dataloader/dic/frame_count.pickle``.

        Keys of the pickled mapping are reduced to the part between the first
        ``'_'`` and the first ``'.'``; the class prefix ``HandStandPushups``
        is rewritten to ``HandstandPushups``.
        """
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # this file from a trusted source.
        with open('dataloader/dic/frame_count.pickle', 'rb') as file:
            dic_frame = pickle.load(file)

        for line in dic_frame:
            videoname = line.split('_', 1)[1].split('.', 1)[0]
            class_name, rest = videoname.split('_', 1)
            if class_name == 'HandStandPushups':
                videoname = 'HandstandPushups_' + rest
            self.frame_count[videoname] = dic_frame[line]

    def run(self):
        """Prepare all dictionaries and return ``(train_loader, val_loader, test_video)``."""
        self.load_frame_count()
        self.get_training_dict()
        self.get_flow_val_sample()
        self.get_rgb_val_sample()
        train_loader = self.get_train_loader()
        val_loader = self.get_val_loader()
        return train_loader, val_loader, self.test_video

    def _nb_clips(self, video):
        """Return how many start positions fit a ``num_frames``-long clip in ``video``."""
        # Uses the configured stack length rather than a hard-coded 10, so the
        # clip range stays correct for other values of num_frames.
        return self.frame_count[video] - self.num_frames + 1

    def _fill_val_dict(self, target):
        """Add ``VAL_SAMPLES_PER_VIDEO`` evenly spaced ``'<video>-<index>'`` keys per test video."""
        for video in self.test_video:
            interval = int(self._nb_clips(video) / self.VAL_SAMPLES_PER_VIDEO)
            for i in range(self.VAL_SAMPLES_PER_VIDEO):
                # Indices are 1-based; a zero interval collapses to one key.
                target[video + '-' + str(i * interval + 1)] = self.test_video[video]

    def get_training_dict(self):
        """Fill the train dicts with ``'<video>-<nb_clips>'`` -> label entries."""
        for video in self.train_video:
            key = video + '-' + str(self._nb_clips(video))
            self.flow_dict_train[key] = self.train_video[video]
            self.rgb_dict_train[key] = self.train_video[video]

    def get_flow_val_sample(self):
        """Fill ``self.flow_dict_test_idx`` with the sampled validation clips."""
        self._fill_val_dict(self.flow_dict_test_idx)

    def get_rgb_val_sample(self):
        """Fill ``self.rgb_dict_test_idx`` with the sampled validation frames."""
        self._fill_val_dict(self.rgb_dict_test_idx)

    def get_train_loader(self):
        """Create the shuffled training ``DataLoader`` and print a short summary."""
        train_dataset = JointDataset(self.flow_dict_train, self.rgb_dict_train,
                                     self.flow_path, self.rgb_path,
                                     num_frames=self.num_frames, mode='train')

        print('\n==> Training data: {} videos'.format(len(train_dataset)))
        if len(train_dataset) > 1:
            # Load the probe sample once; each access reads images from disk.
            flow, rgb, _ = train_dataset[1]
            print('   Flow shape: {}     RGB shape: {}'.format(flow.size(), rgb['img1'].size()))

        train_loader = DataLoader(train_dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=self.num_workers,
                                  drop_last=True,
                                  pin_memory=True)
        return train_loader

    def get_val_loader(self):
        """Create the sequential validation ``DataLoader`` and print a short summary."""
        val_dataset = JointDataset(self.flow_dict_test_idx, self.rgb_dict_test_idx,
                                   self.flow_path, self.rgb_path,
                                   num_frames=self.num_frames, mode='val')

        print('\n==> Validation data: {} frames'.format(len(val_dataset)))
        if len(val_dataset) > 1:
            # Load the probe sample once; each access reads images from disk.
            _, flow, rgb, _ = val_dataset[1]
            print('   Flow shape: {}     RGB shape: {}'.format(flow.size(), rgb.size()))

        val_loader = DataLoader(val_dataset,
                                batch_size=self.batch_size,
                                shuffle=False,
                                num_workers=self.num_workers)
        return val_loader


if __name__ == '__main__':
    pass
    # Disabled manual smoke test. NOTE(review): the snippet below uses
    # `joint_dataloader`, which is never constructed here; build a
    # JointDataloader instance first before re-enabling it.

    # train_loader, val_loader, test_videos = joint_dataloader.run()
    #
    # print('\n')
    # #flow, rgb_dict, label = next(iter(train_loader))
    # #print('Flow: {}'.format(flow.shape))
    # #print('RGB: {}'.format(rgb_dict['img1'].shape))
    # #print('label: {}'.format(label.shape))
    # keys, flow, rgb, label = next(iter(val_loader))